"""
@Time: 2021/2/5 下午 7:48
@Author: jinzhuan
@File: deepke.py
@Desc: Loader for relation-extraction datasets stored as CSV files (train.csv / valid.csv / test.csv).
"""
import os
from ..loader import Loader
import pandas as pd
from cognlp import DataTable


class CsvRelationLoader(Loader):
    """Load relation-extraction examples stored as CSV files.

    Each CSV file must contain the columns ``sentence``, ``relation``,
    ``head`` and ``tail``; any other columns are ignored. Every row is added
    to a ``DataTable``, and every relation label seen is added to
    ``self.label_set`` (not defined here; assumed to be set up by
    ``Loader.__init__``).
    """

    # Columns read from each CSV row, in the order they are passed on.
    _COLUMNS = ('sentence', 'relation', 'head', 'tail')

    def __init__(self):
        super().__init__()

    def _load(self, path):
        """Read one CSV file into a ``DataTable``.

        Args:
            path: path of the CSV file to read.

        Returns:
            A ``DataTable`` that received one ``sentence``, ``relation``,
            ``head`` and ``tail`` value per CSV row, in file order.

        Raises:
            KeyError: if one of the required columns is missing.
        """
        dataset = DataTable()
        # Read every cell as literal text. With pandas' defaults, strings such
        # as "None", "NA" or "null" (e.g. a "no relation" label or an entity
        # name) become NaN, and all-digit entity columns become integers.
        data = pd.read_csv(path, dtype=str, keep_default_na=False)
        # Fetch each column once instead of building a row Series per field;
        # a missing column raises KeyError here, even for an empty file.
        rows = zip(*(data[column].tolist() for column in self._COLUMNS))
        for sentence, relation, head, tail in rows:
            dataset('sentence', sentence)
            dataset('relation', relation)
            self.label_set.add(relation)
            dataset('head', head)
            dataset('tail', tail)
        return dataset

    def load_all(self, path):
        """Load the train, validation and test splits from a directory.

        Args:
            path: directory containing ``train.csv``, ``valid.csv`` and
                ``test.csv``.

        Returns:
            A ``(train, dev, test)`` tuple of ``DataTable`` objects, where
            ``dev`` is read from ``valid.csv``.
        """
        train_path = os.path.join(path, 'train.csv')
        dev_path = os.path.join(path, 'valid.csv')
        test_path = os.path.join(path, 'test.csv')
        return self._load(train_path), self._load(dev_path), self._load(test_path)
